# Estimate the dynamic IRT model (mod_dyn.stan, via cmdstanr) for US sessions
# 79-115.
#
# Expects the packages to be loaded already (e.g. source("00_pckgs.R")) and the
# caller to set `toggle_nonsep`:
#   toggle_nonsep <- 0  # separable model, progress logged in results/us/sirt.txt
#   toggle_nonsep <- 1  # non-zero: non-separable model, logged in results/us/nsirt.txt
# Run the separable model first and the non-separable model second.
#
# Sourcing the script again resumes an interrupted first pass. Once every
# session has been fitted, further runs re-fit only the sessions whose maximum
# R-hat is still >= 1.1, initialised from the estimates of their previous run.
gc()

if (!exists("toggle_nonsep")) {
  stop("`toggle_nonsep` must be set (0 = separable, non-zero = non-separable) ",
       "before running this script.", call. = FALSE)
}

dir.create("results/us/", showWarnings = FALSE, recursive = TRUE)

# Lets cmdstanr run chains in parallel on all available cores.
options(mc.cores = parallel::detectCores())
# set.seed(1917)
# IRT Models ####
# Choose the sessions to fit in this run. Each fitted session appends one line,
# "session <s> <date> <time> <max R-hat>", to the model's log file, so
# read.table() returns the session in V2 and the max R-hat in V5.
#   * log missing or empty            -> first pass over all sessions
#   * fewer lines than sessions       -> resume the first pass after the last
#                                        logged session (assumes the log is in
#                                        session order, as the loop writes it)
#   * at least one line per session   -> refit the sessions whose smallest
#                                        logged max R-hat is still >= 1.1
all_sessions <- 79:115
log_file <- if (toggle_nonsep == 0) {
  "results/us/sirt.txt"
} else {
  "results/us/nsirt.txt"
}

if (!file.exists(log_file)) {
  file.create(log_file)
}
# read.table() errors on an empty file. That is exactly what file.create()
# leaves behind when a first pass stops before its first session is logged.
log_tab <- if (file.size(log_file) > 0) read.table(log_file) else NULL
n_done <- if (is.null(log_tab)) 0L else nrow(log_tab)

if (n_done < length(all_sessions)) {
  sessions <- all_sessions[seq.int(n_done + 1, length(all_sessions))]
  run <- "first"
} else {
  sessions <- log_tab %>%
    group_by(V2) %>%
    summarize(V5 = min(V5)) %>%
    filter(V5 >= 1.1) %>%
    pull(V2)
  run <- "not-first"
}

# # IRT Models ####
# if (run == "first"){
#   if (toggle_nonsep == 0){
#     file.create("results/us/sirt.txt")
#   } else {
#     file.create("results/us/nsirt.txt")
#   }
#   sessions <- full_sessions
# } else {
#   if (toggle_nonsep == 0){
#     cirt_results <- read.table("results/us/sirt.txt")
#   } else {
#     cirt_results <- read.table("results/us/nsirt.txt")
#   }
#   sessions <- cirt_results %>% 
#     group_by(V2) %>% 
#     summarize(V5 = min(V5)) %>% 
#     filter(V5 >= 1.1) %>% 
#     pull(V2)
# }

# Fit the dynamic IRT model session by session ####
# Each fitted session is saved to results/us/<s>/<tag>.Rda (fit_sum, loo_sum,
# mem_desc) and logged as one line in results/us/<tag>.txt. <tag> is "sirt"
# when toggle_nonsep == 0 and "nsirt" otherwise. Sessions after 79 take the
# previous session's saved ideal points (dim1_n / dim2_n) as dim1_o / dim2_o.
model_tag <- if (toggle_nonsep == 0) "sirt" else "nsirt"
log_file <- file.path("results/us", paste0(model_tag, ".txt"))

# Path of the saved results for one session.
results_file <- function(session) {
  file.path("results/us", session, paste0(model_tag, ".Rda"))
}

# Give each member the index of their roll-call row in `rc` (`id`), drop
# members without a row, sort by `id`, and code two covariates:
#   party:  +1 for party_code 200, -1 for 100, 0 otherwise
#   confed: +1 / -1 for the two listed state groups, 0 otherwise
code_members <- function(members, rc) {
  members$id <- match(as.character(members$icpsr), unique(row.names(rc)))
  members %>%
    filter(!is.na(id)) %>%
    arrange(id) %>%
    mutate(
      party = case_when(party_code == 200 ~ 1,
                        party_code == 100 ~ -1,
                        TRUE ~ 0),
      confed = case_when(
        state_abbrev %in% c("SC", "MS", "FL", "AL", "GA", "LA", "TX", "VA",
                            "AR", "TN", "NC") ~ 1,
        state_abbrev %in% c("CA", "OR", "WA", "ME", "NH", "MA", "NY", "NJ",
                            "CT") ~ -1,
        TRUE ~ 0
      )
    )
}

# Per-proposal starting values.
#   Column 1: the mean vote minus 0.5.
#   Columns 2 and 3: the slopes of each vote column on x1 and x2.
# lm() drops rows with NA votes or covariates.
item_inits <- function(rc, x1, x2) {
  inits <- matrix(NA, nrow = ncol(rc), ncol = 3)
  inits[, 1] <- apply(rc, 2, function(x) coef(summary(lm(x ~ 1)))[1, 1]) - .5
  inits[, 2] <- apply(rc, 2, function(x) coef(summary(lm(x ~ x1)))[2, 1])
  inits[, 3] <- apply(rc, 2, function(x) coef(summary(lm(x ~ x2)))[2, 1])
  inits
}

# Assemble the data list for mod_dyn.stan. Missing votes in `rc` must already
# be coded as -1. Only the remaining cells are passed as observations:
# ii is the column (proposal) index and jj is the row (legislator) index.
make_stan_data <- function(rc, theta_in, J_old, T_flag, theta_old,
                           theta_old_ind, incl_nonsep) {
  observed <- rc != -1
  list(
    I = ncol(rc), # proposal
    J = nrow(rc), # legislator
    N = sum(observed),
    ii = col(rc)[observed],
    jj = row(rc)[observed],
    J_old = J_old,
    Y = rc, # yea
    incl_nonsep = incl_nonsep,
    K = 2,
    T = T_flag,
    theta_in = theta_in,
    theta_old = theta_old,
    theta_old_ind = theta_old_ind,
    D = 2
  )
}

# Fresh (first-pass) initial values. This reads the current session's `rc` and
# `pr_init` when cmdstanr calls it.
initF <- function() {
  list(
    gamma = rep(2, 2),
    L = chol(diag(2)),
    theta_free = matrix(0, ncol = 2, nrow = nrow(rc)),
    bk_free = apply(pr_init[, 2:3], 2, function(x) x / sd(x)) +
      rnorm(ncol(rc) * 2, 0, .25),
    ak = pr_init[, 1] / sd(pr_init[, 1]) + rnorm(ncol(rc), 0, .25)
  )
}

# Compile (or reuse the cached build of) the model once, not once per session.
if (length(sessions) > 0) {
  mod <- cmdstan_model("mod_dyn.stan")
}

for (s in sessions) {

  if (run == "first") {
    dir.create(file.path("results/us", s), showWarnings = FALSE, recursive = TRUE)
  } else {
    # Re-initialise from this session's previous (non-converged) fit.
    load(results_file(s))
    prev <- split(fit_sum$mean, fit_sum$parameter)

    # Remove the covariate-explained part of the previous ideal points. The
    # saved session-79 mem_desc has no *_shift columns, because that session
    # is fitted without them, so regress on party / confed alone there.
    theta_prev <- matrix(prev[["theta"]], ncol = 2)
    if (all(c("party_shift", "confed_shift") %in% names(mem_desc))) {
      theta_prev[, 1] <- resid(lm(theta_prev[, 1] ~ mem_desc$party + mem_desc$party_shift))
      theta_prev[, 2] <- resid(lm(theta_prev[, 2] ~ mem_desc$confed + mem_desc$confed_shift))
    } else {
      theta_prev[, 1] <- resid(lm(theta_prev[, 1] ~ mem_desc$party))
      theta_prev[, 2] <- resid(lm(theta_prev[, 2] ~ mem_desc$confed))
    }

    init_prev <- list(
      gamma = prev[["gamma_post"]][3:4],
      L = chol(matrix(prev[["DL"]], ncol = 2)),
      theta_free = apply(theta_prev, 2, function(x) x / sd(x)),
      sal = prev[["sal"]],
      bk_free = matrix(prev[["bk"]], ncol = 2),
      ak = prev[["ak"]]
    )
    initN <- function() init_prev
    rm(fit_sum, mem_desc)
  }

  if (s == 79) {
    # First session: no previous estimates exist, so the nominate.dim1 /
    # nominate.dim2 columns seed the item starting values. theta_old and
    # theta_old_ind are fixed placeholders with T = 0; presumably
    # mod_dyn.stan ignores them in that case (TODO confirm).
    load(file.path("data/us", s, "raw.Rda"))

    mem_desc <- code_members(mem_desc, rc) %>%
      mutate(across(c("party", "confed", "nominate.dim1", "nominate.dim2"),
                    ~ (.x - mean(.x, na.rm = TRUE)) / sd(.x, na.rm = TRUE)))

    pr_init <- item_inits(rc, mem_desc$nominate.dim1, mem_desc$nominate.dim2)

    theta_in <- abind(cbind(mem_desc$party, 0),
                      cbind(mem_desc$confed, 0),
                      along = 3)

    rc[is.na(rc)] <- -1

    stan_data <- make_stan_data(
      rc, theta_in,
      J_old = 2,
      T_flag = 0,
      theta_old = matrix(c(1, 2), ncol = 2, nrow = 2),
      theta_old_ind = c(1, 2),
      incl_nonsep = toggle_nonsep
    )

  } else {
    # Carry the previous session's estimates forward as "old" ideal points.
    load(results_file(s - 1))
    mem_old <- mem_desc %>% select(icpsr, dim1_o = dim1_n, dim2_o = dim2_n)
    rm(fit_sum, mem_desc, loo_sum)

    load(file.path("data/us", s, "raw.Rda"))

    mem_desc <- code_members(mem_desc, rc) %>%
      merge(mem_old, by = "icpsr", all.x = TRUE, sort = FALSE) %>%
      mutate(new = ifelse(is.na(dim1_o), 1, 0)) %>%
      mutate(across(c("party", "confed", "dim1_o", "dim2_o"),
                    ~ (.x - mean(.x, na.rm = TRUE)) / sd(.x, na.rm = TRUE))) %>%
      # Each shift is the member's old position minus the mean old position
      # of their party / confed group.
      group_by(party) %>%
      mutate(party_shift = dim1_o - mean(dim1_o, na.rm = TRUE)) %>%
      ungroup() %>%
      group_by(confed) %>%
      mutate(confed_shift = dim2_o - mean(dim2_o, na.rm = TRUE)) %>%
      ungroup() %>%
      arrange(id)

    # Intercept and slope of the old positions on the standardized covariates.
    # They are used below to draw starting positions for new members.
    gamma_init <- rbind(coef(summary(lm(mem_desc$dim1_o ~ mem_desc$party)))[, 1],
                        coef(summary(lm(mem_desc$dim2_o ~ mem_desc$confed)))[, 1])

    # Computed before the imputation below, so new members (still NA) are
    # dropped by lm().
    pr_init <- item_inits(rc, mem_desc$dim1_o, mem_desc$dim2_o)

    is_new <- mem_desc$new == 1
    mem_desc$dim1_o[is_new] <- rnorm(
      sum(mem_desc$new),
      gamma_init[1, 1] + gamma_init[1, 2] * mem_desc$party[is_new], .25
    )
    mem_desc$dim2_o[is_new] <- rnorm(
      sum(mem_desc$new),
      gamma_init[2, 1] + gamma_init[2, 2] * mem_desc$confed[is_new], .25
    )

    # Rescale the shifts to unit SD. New members had no old position, so
    # their shift is set to 0.
    mem_desc <- mem_desc %>%
      mutate(across(c("party_shift", "confed_shift"), ~ .x / sd(.x, na.rm = TRUE))) %>%
      mutate(party_shift = ifelse(new == 1, 0, party_shift),
             confed_shift = ifelse(new == 1, 0, confed_shift)) %>%
      arrange(id)

    theta_in <- abind(as.matrix(mem_desc[, c("party", "party_shift")]),
                      as.matrix(mem_desc[, c("confed", "confed_shift")]),
                      along = 3)

    rc[is.na(rc)] <- -1

    is_old <- mem_desc$new != 1
    stan_data <- make_stan_data(
      rc, theta_in,
      J_old = sum(is_old),
      T_flag = 1,
      theta_old = mem_desc[is_old, c("dim1_o", "dim2_o")],
      theta_old_ind = mem_desc$id[is_old],
      incl_nonsep = toggle_nonsep
    )
  }

  first_run <- run == "first"
  n_sampling <- if (!first_run) {
    1000
  } else if (s %in% c(97, 103, 104)) {
    750 # longer first run for these sessions
  } else {
    500
  }
  fit <- mod$sample(
    data = stan_data,
    init = if (first_run) initF else initN,
    iter_warmup = 500,
    iter_sampling = n_sampling,
    chains = 8,
    # Reruns use a seed that differs from the first pass, offset by session.
    seed = if (first_run) 2021 else 2021 + s
  )
  end <- Sys.time()

  # Posterior summaries. `parameter` is the variable name without its indices.
  fit_sum <- fit$summary(c("gamma_post", "ak", "sal", "weights", "bk", "theta",
                           "r_old", "r_dim", "r_sq", "m_prop_compl",
                           "sd_prop_compl", "p_share", "DL"))
  fit_sum$parameter <- as.factor(gsub("\\[.*]", "", fit_sum$variable))
  # The convergence value that is logged: the largest non-missing R-hat
  # over ak, weights, bk and theta.
  rhat_core <- fit_sum$rhat[fit_sum$parameter %in% c("ak", "weights", "bk", "theta")]
  r <- max(rhat_core[!is.na(rhat_core)])

  gc(); Sys.sleep(15); gc()

  loo_sum <- fit$loo()

  theta_hat <- matrix(fit_sum$mean[fit_sum$parameter == "theta"], ncol = 2)
  mem_desc <- mem_desc %>% mutate(dim1_n = theta_hat[, 1], dim2_n = theta_hat[, 2])
  save(fit_sum, loo_sum, mem_desc, file = results_file(s))

  rm(fit)
  gc(); Sys.sleep(15); gc()

  write(paste0("session ", s, " ", end, " ", round(r, 3)),
        file = log_file, append = TRUE)
}
